using namespace metal;

// Per-draw transform matrices. vertex_main receives this struct as a constant
// buffer bound at index 1.
struct Matrices
{
    // Multiplied with the homogeneous vertex position to produce the [[position]] output.
    float4x4 viewProjectionMatrix;
    // Multiplied with the homogeneous vertex position to produce FragmentIn::positionCS.
    float4x4 viewMatrix;
    // Not referenced by the shaders visible in this file.
    float4x4 normalMatrix;
};

// One entry of a per-pixel fragment linked list stored in a device buffer.
struct FragmentListNode
{
    // Index (into the node buffer) of the next entry of the same list;
    // FL_NODE_LIST_NULL (0) ends the list.
    uint    next;
    // Depth with a coverage mask folded into the 4 lowest bits of its bit
    // pattern; decode with FL_UnpackDepthAndCoverage.
    float   depth;
    // Four 8-bit channels packed into one word (x in the lowest byte, w in the
    // highest); decode with FL_UnpackColor.
    uint    color;
};

// Per-vertex input fetched through the stage-in mechanism.
struct VertexIn
{
    // 3D vertex position read from vertex attribute 0; vertex_main extends it with w = 1.
    float3 position [[attribute(0)]];
};

// Values written by vertex_main and received by fragment_main.
struct FragmentIn
{
    // Output of viewProjectionMatrix * position; fragment_main reads its x/y
    // to derive the pixel's index into the per-pixel list-head buffer.
    float4 position [[position]];
    // xyz of viewMatrix * position (presumably "camera space"). Not read by
    // fragment_main in this file.
    float3 positionCS;
    // Per-sample input is currently disabled; fragment_main resolves sample 0 only.
    //uint sampleIndex [[sample_id]];
};

// Decodes a single sRGB-encoded channel value to linear light using the
// piecewise sRGB curve: a straight segment near black and a 2.4 power curve
// above the break point.
float srgbToLinear( float s )
{
	const bool onLinearSegment = ( s <= 0.0404482362771082f );
	return onLinearSegment ? ( s / 12.92f )
	                       : pow( (s + 0.055f) / 1.055f, 2.4f );
}

// Decodes the r, g and b channels of an sRGB-encoded colour to linear light
// one channel at a time; the alpha channel is passed through untouched.
float4 srgbToLinear( float4 color )
{
	const float linearR = srgbToLinear( color.r );
	const float linearG = srgbToLinear( color.g );
	const float linearB = srgbToLinear( color.b );
	return float4( linearR, linearG, linearB, color.a );
}

// Expands a colour packed as four 8-bit channels into a float4 with every
// channel scaled to [0, 1]. The lowest byte becomes x, the highest byte w.
float4 FL_UnpackColor( uint packedInput )
{
    const uint byteMask = 0xFFU;
    const uint4 channelBytes = uint4( packedInput & byteMask,
                                      (packedInput >> 8U) & byteMask,
                                      (packedInput >> 16U) & byteMask,
                                      packedInput >> 24U );
    return float4( channelBytes ) / 255;
}
// Splits a packed depth/coverage value into its two parts.
//   packedDepthCovg - float whose bit pattern carries a coverage mask in its 4 lowest bits.
//   depth           - receives the input with those 4 bits cleared.
//   coverage        - receives the 4-bit coverage mask.
void FL_UnpackDepthAndCoverage( float packedDepthCovg, thread float& depth, thread uint& coverage )
{
    const uint coverageBits = 0xFU;
    const uint depthBits    = 0xFFFFFFF0U;

    const uint rawBits = as_type<uint>( packedDepthCovg );
    depth    = as_type<float>( rawBits & depthBits );
    coverage = rawBits & coverageBits;
}

// Node index that terminates a fragment list. AOITResolvePS stops walking a
// list when it reaches this value, so index 0 is never read as a real node there.
#define FL_NODE_LIST_NULL (0x0U)

// Number of nodes kept in the transmission curve. AOITFindInsertIndex and
// AOITGetTransmissionAtDepth hard-code a split into two blocks of four, so they
// assume this value is 8.
#define OIT_NODE_COUNT            8
// Depth stored in curve slots that hold no fragment yet.
#define IOT_EMPTY_NODE_DEPTH    1E30f
//Define this to skip compression of nodes closer to eye
// (when defined, AOITInsertFragmentNew only considers the back half of the
// curve, plus the overflow node, as removal candidates).
#define AOIT_DONT_COMPRESS_FIRST_HALF

//After sorting, this struct represents the transmission value as a function of depth
// Node i pairs depth[i] with trans[i]. AOITResolvePS initialises every slot to
// (IOT_EMPTY_NODE_DEPTH, 1.0) and AOITInsertFragmentNew fills the slots in.
struct AOITTransmissionFunctionData
{
    // Node depths; unused slots hold IOT_EMPTY_NODE_DEPTH.
    float depth[OIT_NODE_COUNT];
    // Transmittance recorded at each node. The last entry is what
    // AOITResolvePS uses to derive the final alpha.
    float trans[OIT_NODE_COUNT];
};

// Finds the slot at which a fragment of the given depth belongs in the curve.
//   data          - current transmission curve.
//   fragmentDepth - depth of the fragment being placed.
//   trans         - receives the transmittance of the node just before the
//                   returned slot, or 1.0 when the returned slot is 0.
// Returns the index of the first examined node whose depth is >= fragmentDepth,
// or one past the examined block when there is none (at most 8).
int AOITFindInsertIndex( AOITTransmissionFunctionData data, float fragmentDepth, thread float& trans )
{
    // The curve is searched as two blocks of four nodes; one comparison against
    // the last node of the front block selects which block to scan.
    const int blockSize = 4;
    const int blockStart = ( fragmentDepth > data.depth[blockSize - 1] ) ? blockSize : 0;
    const int blockEnd = blockStart + blockSize;

    // Advance past every node of the block that lies strictly in front of the fragment.
    int slot = blockStart;
    while( slot < blockEnd && !( fragmentDepth <= data.depth[slot] ) )
        ++slot;

    // Nothing precedes slot 0, so light is still fully transmitted there.
    trans = ( slot == 0 ) ? 1.0f : data.trans[slot - 1];
    return slot;
}
// Evaluates the transmission curve at fragmentDepth: the transmittance of the
// node just before the slot the depth falls into, or 1.0 when it falls into
// the first slot.
float AOITGetTransmissionAtDepth( AOITTransmissionFunctionData data, float fragmentDepth )
{
    // The insert-index search already produces exactly this value as its
    // by-reference output; the slot index itself is not needed here.
    float transmittance;
    AOITFindInsertIndex( data, fragmentDepth, transmittance );
    return transmittance;
}

// Adds one fragment to the transmission curve and, if the curve was already
// full, removes one node again so it still fits in OIT_NODE_COUNT slots.
//   fragmentDepth - depth of the fragment being added.
//   fragmentTrans - transmittance of the fragment (AOITResolvePS passes
//                   saturate(1 - alpha)).
//   data          - curve updated in place. Every node at or behind the insert
//                   position has its transmittance multiplied by fragmentTrans.
void AOITInsertFragmentNew( float fragmentDepth, float fragmentTrans, thread AOITTransmissionFunctionData& data )
{
    // prevTrans receives the transmittance of the node just in front of the insert slot.
    float prevTrans;
    const int index = AOITFindInsertIndex( data, fragmentDepth, prevTrans );
    
    int i;
    // Sampled before anything is shifted: the curve is full when its last slot
    // no longer holds the empty-node sentinel.
    bool dataIsFull = data.depth[OIT_NODE_COUNT - 1] != IOT_EMPTY_NODE_DEPTH;
    
    // Make space for the new fragment. Also composite new fragment with the current curve
    // (except for the node that represents the new fragment)
    // depthLast/transLast act as an extra overflow node (logical index
    // OIT_NODE_COUNT) that does not fit in the arrays. They are only consumed
    // by the compression step below; when the curve was not full they are dropped.
    float depthLast, transLast;
    if( index == OIT_NODE_COUNT )
    {
        //New index follows current items
        depthLast = fragmentDepth;
        transLast = fragmentTrans * prevTrans;
    }
    else
    {
        //Save item that will be shifted out
        depthLast = data.depth[OIT_NODE_COUNT - 1];
        transLast = data.trans[OIT_NODE_COUNT - 1] * fragmentTrans;
        //Shift existing entries down
        // Walks back to front so each entry is copied before it is overwritten.
        // The loop bounds are constant; the condition on index selects which
        // entries actually move.
        for( i = OIT_NODE_COUNT - 2; i >= 0; --i )
        {
            if( index <= i )
            {
                data.depth[i + 1] = data.depth[i];
                data.trans[i + 1] = data.trans[i] * fragmentTrans;
            }
        }
        //Set new entry
        // Written as a constant-bound loop with a compare rather than a direct
        // data.depth[index] store.
        for( i = 0; i < OIT_NODE_COUNT; ++i )
        {
            if( index == i )
            {
                data.depth[i] = fragmentDepth;
                data.trans[i] = fragmentTrans * prevTrans;
            }
        }
    }
    
    //Pack representation if we have too many nodes
    if( dataIsFull )
    {
        // Candidate indices run up to and including the overflow node.
        const int removalCandidateCount = OIT_NODE_COUNT + 1;
        
#ifdef AOIT_DONT_COMPRESS_FIRST_HALF
        //Favor nodes closest to the eye, skip the first half
        const int startRemovalIdx = removalCandidateCount / 2;
#else
        const int startRemovalIdx = 1;
#endif
        
        // Error of removing node i: the depth gap to the previous node times the
        // transmittance drop across it. Entries below startRemovalIdx are never
        // written or read.
        float nodeUnderError[removalCandidateCount];
        for( i = startRemovalIdx; i < OIT_NODE_COUNT; ++i )
            nodeUnderError[i] = (data.depth[i] - data.depth[i - 1]) * (data.trans[i - 1] - data.trans[i]);
        
        // Same measure for the overflow node, which lives in depthLast/transLast.
        nodeUnderError[OIT_NODE_COUNT] = (depthLast - data.depth[OIT_NODE_COUNT - 1]) * (data.trans[OIT_NODE_COUNT - 1] - transLast);
        
        //Find the node that generates the smallest removal error
        // Strict '<' keeps the first (nearest) candidate on ties.
        int smallestErrorIdx = startRemovalIdx;
        float smallestError = nodeUnderError[smallestErrorIdx];
        for( i = startRemovalIdx + 1; i < removalCandidateCount; ++i )
        {
            if( nodeUnderError[i] < smallestError )
            {
                smallestError = nodeUnderError[i];
                smallestErrorIdx = i;
            }
        }
        
        //Remove that node
        // Removing node k discards depth[k] and overwrites trans[k - 1] with
        // trans[k]; everything behind it moves up by one slot. The transmittance
        // shift therefore starts one slot earlier than the depth shift.
        for( i = startRemovalIdx - 1; i < OIT_NODE_COUNT - 1; ++i )
        {
            if( i >= smallestErrorIdx - 1 )
                data.trans[i] = data.trans[i + 1];
            
            if( i >= smallestErrorIdx )
                data.depth[i] = data.depth[i + 1];
        }
        
        // The overflow node moves into the last slot. If the overflow node is the
        // one being removed, its depth is discarded and only its transmittance
        // replaces that of the last slot.
        if( smallestErrorIdx != OIT_NODE_COUNT )
            data.depth[OIT_NODE_COUNT - 1] = depthLast;
        data.trans[OIT_NODE_COUNT - 1] = transLast;
    }
}

// Resolves one pixel's fragment list into a single colour.
//   nodeList        - buffer holding the nodes of all fragment lists.
//   firstNodeOffset - index of the pixel's first node; FL_NODE_LIST_NULL means
//                     the list is empty.
//   sampleIndex     - sample to resolve; only fragments whose coverage mask has
//                     this bit set contribute.
// Returns the summed fragment colours (each weighted by its own alpha and by
// the transmittance in front of it) in xyz, and one minus the last curve
// node's transmittance in w.
float4 AOITResolvePS( device FragmentListNode* nodeList, uint firstNodeOffset, uint sampleIndex )
{
    // Start from an empty curve: every slot is the far sentinel with full transmittance.
    AOITTransmissionFunctionData curve;
    for( uint slot = 0; slot < OIT_NODE_COUNT; ++slot )
    {
        curve.trans[slot] = 1.0f;
        curve.depth[slot] = IOT_EMPTY_NODE_DEPTH;
    }

    // Bit to test against each fragment's coverage mask.
    int sampleBit = 1 << sampleIndex;

    // Pass 1: feed every covering fragment into the transmission curve.
    for( uint cursor = firstNodeOffset; cursor != FL_NODE_LIST_NULL; )
    {
        FragmentListNode fragment = nodeList[ cursor ];
        cursor = fragment.next;

        float fragmentDepth;
        uint fragmentCoverage;
        FL_UnpackDepthAndCoverage( fragment.depth, fragmentDepth, fragmentCoverage );
        if( !( sampleBit & fragmentCoverage ) )
            continue;

        float4 rgba = FL_UnpackColor( fragment.color );
        AOITInsertFragmentNew( fragmentDepth, saturate( 1.0 - rgba.w ), curve );
    }

    // Pass 2: walk the same list again and accumulate each covering fragment's
    // colour, attenuated by what the finished curve lets through at its depth.
    float3 accumulated = float3(0, 0, 0);
    for( uint cursor = firstNodeOffset; cursor != FL_NODE_LIST_NULL; )
    {
        FragmentListNode fragment = nodeList[ cursor ];
        cursor = fragment.next;

        float fragmentDepth;
        uint fragmentCoverage;
        FL_UnpackDepthAndCoverage( fragment.depth, fragmentDepth, fragmentCoverage );
        if( !( sampleBit & fragmentCoverage ) )
            continue;

        float4 rgba = FL_UnpackColor( fragment.color );
        float visibility = AOITGetTransmissionAtDepth( curve, fragmentDepth );
        // NOTE(review): the colour space this accumulation happens in has not
        // been verified (carried over from the original author's remark).
        accumulated += rgba.xyz * rgba.www * visibility;
    }

    // Whatever the whole curve does not transmit is the pixel's net alpha.
    return float4(accumulated, 1 - curve.trans[OIT_NODE_COUNT - 1]);
}

// Vertex stage: extends the incoming position with w = 1 and transforms it by
// the view-projection matrix (rasterizer position) and by the view matrix
// (xyz stored in positionCS).
vertex FragmentIn vertex_main(VertexIn IN [[stage_in]], constant Matrices& matrices [[buffer(1)]])
{
    const float4 homogeneousPosition = float4(IN.position, 1);

    FragmentIn OUT;
    OUT.positionCS = (matrices.viewMatrix * homogeneousPosition).xyz;
    OUT.position = matrices.viewProjectionMatrix * homogeneousPosition;
    return OUT;
}

// Fragment stage: looks up the head of this pixel's fragment list, resolves the
// list for sample 0, and runs the result through srgbToLinear before returning it.
//   firstNodeIndex - one list-head index per pixel, laid out row by row with
//                    viewWidth entries per row.
//   nodeList       - storage for the nodes of every list.
//   viewWidth      - number of pixels per row used for the lookup.
fragment float4 fragment_main(FragmentIn IN [[stage_in]], device uint* firstNodeIndex [[buffer(5)]], device FragmentListNode* nodeList [[buffer(3)]], constant uint& viewWidth [[buffer(6)]])
{
    // Truncate the fragment position to integer pixel coordinates.
    const ushort column = (ushort)IN.position.x;
    const ushort row = (ushort)IN.position.y;

    const uint pixelIndex = column + row * viewWidth;
    const uint listHead = firstNodeIndex[pixelIndex];

    const float4 resolved = AOITResolvePS(nodeList, listHead, 0);
    return srgbToLinear( resolved );
}
